{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "df840597-64ce-4834-852e-48ced451f69f"
      },
      "source": [
        "<a target=\"_blank\" href=\"https://colab.research.google.com/github/google-ai-edge/ai-edge-torch/blob/main/ai_edge_torch/generative/examples/t5/t5_conversion_colab.ipynb\">\n",
        "  <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
        "</a>"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "00e00b3b-d7ed-4e2e-815e-3addfc23c8f3"
      },
      "outputs": [],
      "source": [
        "# Copyright 2024 The AI Edge Torch Authors.\n",
        "#\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "#     http://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License.\n",
        "# ==============================================================================\n",
        "# This is a simple colab showing how to re-author T5 (encoder-decoder) model,\n",
        "# convert and run in a colab python environment."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "a9bdc007e6ce"
      },
      "source": [
        "Note: When running notebooks in this repository with Google Colab, some users may see\n",
        "the following warning message:\n",
        "\n",
        "![Colab warning](https://github.com/google-ai-edge/ai-edge-torch/blob/main/docs/data/colab_warning.jpg?raw=true)\n",
        "\n",
        "Please click `Restart Session` and run again."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "a91d40b5-91f0-4c19-bdb4-a2f56fa1c5ff"
      },
      "outputs": [],
      "source": [
        "!pip install -r https://raw.githubusercontent.com/google-ai-edge/ai-edge-torch/main/requirements.txt\n",
        "!pip install ai-edge-torch-nightly"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WpEGGRs4FJo6"
      },
      "source": [
        "## Download model checkpoint\n",
        "First we download the T5 pytorch checkpoint from huggingface from https://huggingface.co/humarin/chatgpt_paraphraser_on_T5_base."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ivZdosGMowfl"
      },
      "outputs": [],
      "source": [
        "!curl -O -L https://huggingface.co/humarin/chatgpt_paraphraser_on_T5_base/resolve/main/pytorch_model.bin"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "R5aPvvxeF3bc"
      },
      "source": [
        "## T5 Model Authoring and Conversion\n",
        "Next, we import the T5 encoder/decoder implementation from `ai_edge_torch/generative/examples/t5`, and convert to TFLite with 2 signatures: `encode` and `decode`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "YqPe7B8hwGh2"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import torch\n",
        "\n",
        "import ai_edge_torch\n",
        "from ai_edge_torch.generative.examples.t5 import t5\n",
        "from ai_edge_torch.generative.quantize import quant_recipes\n",
        "\n",
        "\n",
        "def convert_t5_to_tflite_multisig(checkpoint_path: str):\n",
        "  config = t5.get_model_config_t5()\n",
        "  # Temporarily disable HLFB until custom op issue is fixed.\n",
        "  config.enable_hlfb = False\n",
        "  embedding_layer = torch.nn.Embedding(\n",
        "      config.vocab_size, config.embedding_dim, padding_idx=0\n",
        "  )\n",
        "  t5_encoder_model = t5.build_t5_encoder_model(config, embedding_layer, checkpoint_path)\n",
        "  t5_decoder_model = t5.build_t5_decoder_model(config, embedding_layer, checkpoint_path)\n",
        "\n",
        "  # encoder\n",
        "  seq_len = 512\n",
        "  prefill_e_tokens = torch.full((1, seq_len), 0, dtype=torch.int)\n",
        "  prompt_e_token = [1, 2, 3, 4, 5, 6]\n",
        "  prefill_e_tokens[0, : len(prompt_e_token)] = torch.tensor(\n",
        "      prompt_e_token, dtype=torch.int\n",
        "  )\n",
        "  prefill_e_input_pos = torch.arange(0, seq_len, dtype=torch.int)\n",
        "  prefill_d_tokens = torch.full((1, seq_len), 0, dtype=torch.int)\n",
        "  prompt_d_token = [1, 2, 3, 4, 5, 6]\n",
        "  prefill_d_tokens[0, : len(prompt_d_token)] = torch.tensor(\n",
        "      prompt_d_token, dtype=torch.int\n",
        "  )\n",
        "  prefill_d_input_pos = torch.arange(0, seq_len, dtype=torch.int)\n",
        "\n",
        "  # decoder\n",
        "  decode_token = torch.tensor([[1]], dtype=torch.int)\n",
        "  decode_input_pos = torch.tensor([0], dtype=torch.int)\n",
        "  decode_d_token = torch.tensor([[1]], dtype=torch.int)\n",
        "  decode_d_input_pos = torch.tensor([0], dtype=torch.int)\n",
        "\n",
        "  # Pad mask for self attention only on \"real\" tokens.\n",
        "  # Pad with `-inf` for any tokens indices that aren't desired.\n",
        "  pad_mask = torch.zeros([seq_len], dtype=torch.float32)\n",
        "  hidden_states = torch.zeros((1, 512, 768), dtype=torch.float32)\n",
        "  quant_config = quant_recipes.full_int8_dynamic_recipe()\n",
        "\n",
        "  edge_model = ai_edge_torch.signature(\n",
        "          'encode',\n",
        "          t5_encoder_model.eval(),\n",
        "          (\n",
        "              prefill_e_tokens,\n",
        "              prefill_e_input_pos,\n",
        "              pad_mask,\n",
        "          ),\n",
        "      ).signature(\n",
        "          'decode',\n",
        "          t5_decoder_model.eval(),\n",
        "          (\n",
        "              hidden_states,\n",
        "              decode_d_token,\n",
        "              decode_d_input_pos,\n",
        "              pad_mask,\n",
        "          ),\n",
        "      ).convert(quant_config=quant_config)\n",
        "\n",
        "  edge_model.export('t5_encode_decode_2_sigs.tflite')\n",
        "  return edge_model\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UlgHZtWIGhAc"
      },
      "source": [
        "Finally, we call the convert function, this might take a few minutes to finish."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nMCuhDTawncf"
      },
      "outputs": [],
      "source": [
        "print('converting T5 to tflite.')\n",
        "edge_model = convert_t5_to_tflite_multisig(\"/content/pytorch_model.bin\")"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "name": "t5_conversion_colab.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
